# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

from haystack.core.errors import DeserializationError, SerializationError
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
from haystack.utils import deserialize_callable, serialize_callable


def serialize_class_instance(obj: Any) -> dict[str, Any]:
    """
    Wraps the `to_dict` output of an object together with its class name.

    :param obj:
        The object to serialize. It must expose a `to_dict` attribute.
    :returns:
        A dictionary with two keys: `type`, the name produced by `generate_qualified_class_name`
        for the object's class, and `data`, the value returned by `obj.to_dict()`.
    :raises SerializationError:
        If the object does not have a `to_dict` method.
    """
    if hasattr(obj, "to_dict"):
        # Call to_dict() before resolving the class name, so a failing to_dict() surfaces first
        serialized_state = obj.to_dict()
        qualified_name = generate_qualified_class_name(type(obj))
        return {"type": qualified_name, "data": serialized_state}

    raise SerializationError(f"Object of class '{type(obj).__name__}' does not have a 'to_dict' method")


def deserialize_class_instance(data: dict[str, Any]) -> Any:
    """
    Rebuilds an object from the dictionary representation generated by `serialize_class_instance`.

    :param data:
        The dictionary to deserialize from. It must contain a `type` key (the class name to
        import) and a `data` key (the payload handed to the class's `from_dict` method).
    :returns:
        The object returned by the class's `from_dict` method.
    :raises DeserializationError:
        If the `type` or `data` key is missing, the class cannot be imported, or the
        class does not have a `from_dict` method.
    """
    if "type" not in data:
        raise DeserializationError("Missing 'type' in serialization data")
    if "data" not in data:
        raise DeserializationError("Missing 'data' in serialization data")

    class_name = data["type"]
    try:
        target_class = import_class_by_name(class_name)
    except ImportError as import_error:
        # Re-raise as a deserialization failure while keeping the original cause attached
        raise DeserializationError(f"Class '{class_name}' not correctly imported") from import_error

    if hasattr(target_class, "from_dict"):
        return target_class.from_dict(data["data"])

    raise DeserializationError(f"Class '{class_name}' does not have a 'from_dict' method")


def _serialize_value_with_schema(payload: Any) -> dict[str, Any]:
    """
    Produces a schema-annotated serialization of a value.

    The result keeps the type description separate from the data itself, so that nested
    structures can be rebuilt from the schema later on.

    Values are handled in this order of precedence:
    - Dictionaries: every value is serialized recursively.
    - Lists, tuples, and sets: the item schema is taken from the first element only, so
      collections with mixed item types are not supported.
    - Objects with a callable `to_dict` attribute.
    - Callables that are not classes.
    - Objects with a `__dict__` attribute.
    - Everything else is treated as a primitive (str, int, float, bool, None).

    :param payload: The value to serialize (can be any type)
    :returns: The serialized dict representation of the given value. Contains two keys:
        - "serialization_schema": Contains type information for the value and its fields/items.
        - "serialized_data": Contains the actual data in a simplified format.

    """
    # Dictionaries: serialize every entry once, then split the results into schema and data
    if isinstance(payload, dict):
        serialized_fields = {key: _serialize_value_with_schema(item) for key, item in payload.items()}
        properties = {key: entry["serialization_schema"] for key, entry in serialized_fields.items()}
        contents = {key: entry["serialized_data"] for key, entry in serialized_fields.items()}
        return {"serialization_schema": {"type": "object", "properties": properties}, "serialized_data": contents}

    # Lists, tuples, and sets are all emitted as a plain list plus an "array" schema
    if isinstance(payload, (list, tuple, set)):
        flattened = _convert_to_basic_types(list(payload))

        # Only the first element is inspected to describe the items; empty collections get {}
        if payload:
            items_schema = _serialize_value_with_schema(next(iter(payload)))["serialization_schema"]
        else:
            items_schema = {}
        array_schema: dict[str, Any] = {"type": "array", "items": items_schema}

        # JSON Schema keywords double as markers that let the reader tell sets and tuples apart
        if isinstance(payload, set):
            array_schema["uniqueItems"] = True
        elif isinstance(payload, tuple):
            size = len(payload)
            array_schema["minItems"] = size
            array_schema["maxItems"] = size

        return {"serialization_schema": array_schema, "serialized_data": flattened}

    # Objects exposing to_dict(): record the class name next to the converted data
    if hasattr(payload, "to_dict") and callable(payload.to_dict):
        class_path = generate_qualified_class_name(type(payload))
        converted = _convert_to_basic_types(payload)
        return {"serialization_schema": {"type": class_path}, "serialized_data": converted}

    # Callables (classes excluded) are delegated to serialize_callable
    if callable(payload) and not isinstance(payload, type):
        callable_repr = serialize_callable(payload)
        return {"serialization_schema": {"type": "typing.Callable"}, "serialized_data": callable_repr}

    # Any other object that carries instance attributes
    if hasattr(payload, "__dict__"):
        class_path = generate_qualified_class_name(type(payload))
        converted = _convert_to_basic_types(vars(payload))
        return {"serialization_schema": {"type": class_path}, "serialized_data": converted}

    # Whatever is left is stored as-is with a primitive type tag
    return {"serialization_schema": {"type": _primitive_schema_type(payload)}, "serialized_data": payload}


def _primitive_schema_type(value: Any) -> str:
    """
    Maps a primitive Python value to the type name used in the serialization schema.

    :param value: The value to classify.
    :returns: "null" for None, "boolean" for bool, "integer" for int, "number" for float,
        and "string" for str as well as for any other value.
    """
    if value is None:
        return "null"

    # bool is listed before int because bool is a subclass of int
    type_names = ((bool, "boolean"), (int, "integer"), (float, "number"))
    for python_type, schema_name in type_names:
        if isinstance(value, python_type):
            return schema_name

    # Strings and every unrecognized value share the same tag
    return "string"


def _convert_to_basic_types(value: Any) -> Any:
    """
    Helper function to recursively convert complex Python objects into their basic type equivalents.

    This helper function traverses through nested data structures and converts complex
    objects (custom classes, dataclasses, etc.) into basic Python types (dict, list, str,
    int, float, bool, None).

    Values are handled in this order of precedence:
    - Objects with a callable `to_dict` attribute: the result of `to_dict()` is converted recursively
    - Callables that are not classes: converted to None
    - Dictionaries (including subclasses): values are converted recursively, keys are kept as they are
    - Lists, tuples, and sets (including subclasses): converted to a plain list of converted items
    - Objects with a `__dict__` attribute: converted to a plain dictionary of their attributes
    - Everything else: returned as-is

    :param value: The value to convert.
    :returns: The converted value.
    """
    # Objects exposing to_dict()
    if hasattr(value, "to_dict") and callable(value.to_dict):
        return _convert_to_basic_types(value.to_dict())

    # Callables (other than classes) are not representable here, so they become None
    if callable(value) and not isinstance(value, type):
        return None

    # Containers are checked before the generic __dict__ fallback below: instances of
    # dict/list/tuple/set subclasses usually carry an instance __dict__ as well, and
    # converting them from their attributes would silently drop their contents.
    if isinstance(value, dict):
        return {k: _convert_to_basic_types(v) for k, v in value.items()}

    # Every supported collection type is flattened to a list
    if isinstance(value, (list, tuple, set)):
        return [_convert_to_basic_types(v) for v in value]

    # Arbitrary objects with instance attributes
    if hasattr(value, "__dict__"):
        return {k: _convert_to_basic_types(v) for k, v in vars(value).items()}

    # Primitive (or otherwise unrecognized) value
    return value


def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any:  # pylint: disable=too-many-return-statements, # noqa: PLR0911, PLR0912
    """
    Deserializes a value with schema information back to its original form.

    Takes a dict of the form:
      {
         "serialization_schema": {"type": "integer"} or {"type": "object", "properties": {...}},
         "serialized_data": <the actual data>
      }

    Arrays are rebuilt as a set when the schema has `"uniqueItems": True`, as a tuple when both
    `minItems` and `maxItems` are present, and as a list otherwise. This applies to every item
    type, including nested objects and arrays.

    :param serialized: The serialized dict with schema and data.
    :returns: The deserialized value in its original form.
    :raises DeserializationError:
        If `serialized` is empty or lacks one of the two expected keys, or if the schema has
        no `type` entry.
    """

    if not serialized or "serialization_schema" not in serialized or "serialized_data" not in serialized:
        raise DeserializationError(
            f"Invalid format of passed serialized payload. Expected a dictionary with keys "
            f"'serialization_schema' and 'serialized_data'. Got: {serialized}"
        )
    schema = serialized["serialization_schema"]
    data = serialized["serialized_data"]

    schema_type = schema.get("type")

    if not schema_type:
        # A schema without a "type" entry is rejected outright; there is no fallback path
        raise DeserializationError(
            "Missing 'type' key in 'serialization_schema'. This likely indicates that you're using a serialized "
            "State object created with a version of Haystack older than 2.15.0. "
            "Support for the old serialization format is removed in Haystack 2.16.0. "
            "Please upgrade to the new serialization format to ensure forward compatibility."
        )

    # Handle object case (dictionary with properties)
    if schema_type == "object":
        properties = schema.get("properties")
        if not properties:
            # No per-field schema available: fall back to envelope-based deserialization
            return _deserialize_value(data)

        result: dict[str, Any] = {}
        if isinstance(data, dict):
            for field, raw_value in data.items():
                field_schema = properties.get(field)
                # Fields without a schema entry are skipped
                if field_schema:
                    result[field] = _deserialize_value_with_schema(
                        {"serialization_schema": field_schema, "serialized_data": raw_value}
                    )
        return result

    # Handle array case
    if schema_type == "array":
        item_schema = schema.get("items", {})
        item_type = item_schema.get("type", "any")

        def deserialize_item(item: Any) -> Any:
            # Nested objects/arrays need the full item schema to be rebuilt correctly
            if item_type in ("object", "array"):
                return _deserialize_value_with_schema({"serialization_schema": item_schema, "serialized_data": item})
            # Without an item type, rely on whatever envelope information the item carries
            if item_type == "any":
                return _deserialize_value(item)
            return _deserialize_value({"type": item_type, "data": item})

        items = [deserialize_item(item) for item in data]

        # Restore the original container kind for every item type. Previously nested
        # objects/arrays always came back as a list, so e.g. a tuple of dicts lost its type.
        if schema.get("uniqueItems") is True:
            return set(items)
        if schema.get("minItems") is not None and schema.get("maxItems") is not None:
            return tuple(items)
        return items

    # Handle primitive types
    if schema_type in ("null", "boolean", "integer", "number", "string"):
        return data

    # Handle callable functions
    if schema_type == "typing.Callable":
        return deserialize_callable(data)

    # Handle custom class types
    return _deserialize_value({"type": schema_type, "data": data})


def _deserialize_value(value: Any) -> Any:  # pylint: disable=too-many-return-statements # noqa: PLR0911
    """
    Helper function that rebuilds a value from its envelope format {"type": T, "data": D}.

    Four kinds of input are distinguished:
    - Collections (list/tuple/set): every element is deserialized and the same container type is rebuilt
    - Typed envelopes: dicts holding both "type" and "data", where the type decides how the data is rebuilt
    - Plain dicts: every value is deserialized recursively
    - Other values: returned as-is

    :param value: The value to deserialize
    :returns: The deserialized value

    """
    # Collections: rebuild the same container type from the deserialized elements
    if isinstance(value, (list, tuple, set)):
        container = type(value)
        return container(_deserialize_value(element) for element in value)

    # Anything that is neither a collection nor a dict passes through untouched
    if not isinstance(value, dict):
        return value

    # Plain dict without the envelope keys: recurse into the values
    is_envelope = "type" in value and "data" in value
    if not is_envelope:
        return {key: _deserialize_value(item) for key, item in value.items()}

    type_tag = value["type"]
    content = value["data"]

    # Array envelope
    if type_tag == "array":
        return [_deserialize_value(element) for element in content]

    # Generic object/dict envelope
    if type_tag == "object":
        return {key: _deserialize_value(item) for key, item in content.items()}

    # Primitive envelope: the data is already in its final form
    if type_tag in ("null", "boolean", "integer", "number", "string"):
        return content

    # Callable envelope
    if type_tag == "typing.Callable":
        return deserialize_callable(content)

    # Custom class: import it first, then deserialize the attributes it was stored with
    target_cls = import_class_by_name(type_tag)
    attributes = {key: _deserialize_value(item) for key, item in content.items()}

    # Prefer the class's own from_dict when it offers one
    if hasattr(target_cls, "from_dict") and callable(target_cls.from_dict):
        return target_cls.from_dict(attributes)

    # Otherwise create an instance without calling __init__ and assign the attributes directly
    blank_instance = target_cls.__new__(target_cls)
    for attr_name, attr_value in attributes.items():
        setattr(blank_instance, attr_name, attr_value)
    return blank_instance
